{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-12-29T15:21:07.745948Z",
     "start_time": "2024-12-29T15:20:58.157667Z"
    }
   },
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "\n",
    "# 加载模型和 tokenizer\n",
    "def load_model_and_tokenizer(model_name, device=\"cuda\"):\n",
    "    print(f\"Loading model: {model_name}\")\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_name,\n",
    "        torch_dtype=torch.float16,\n",
    "        device_map=\"auto\",\n",
    "        attn_implementation=\"eager\",\n",
    "    )\n",
    "    return model, tokenizer"
   ],
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\fzkuj\\anaconda3\\envs\\nano\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-29T15:21:41.342346Z",
     "start_time": "2024-12-29T15:21:07.755001Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model_name = \"mistralai/Mistral-7B-Instruct-v0.1\"\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# 加载模型和 tokenizer\n",
    "model, tokenizer = load_model_and_tokenizer(model_name, device)\n",
    "\n",
    "# 加载和分词输入\n",
    "dataset = load_dataset(\"emozilla/pg19\", split=\"test\", trust_remote_code=True)\n",
    "text = dataset[0][\"text\"][:16384]  # 仅取前 16384 个字符\n",
    "input_ids = tokenizer(text, return_tensors=\"pt\", max_length=16384, truncation=True).input_ids\n"
   ],
   "id": "a68fa0440578ca16",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model: mistralai/Mistral-7B-Instruct-v0.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:21<00:00, 10.56s/it]\n",
      "Some parameters are on the meta device because they were offloaded to the cpu.\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-29T15:21:41.856865Z",
     "start_time": "2024-12-29T15:21:41.851101Z"
    }
   },
   "cell_type": "code",
   "source": [
    "print(text[:100])\n",
    "print(input_ids[:100])"
   ],
   "id": "30ed786303dab03a",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ST. PAUL***\n",
      "\n",
      "\n",
      "E-text prepared by Josephine Paolucci and the Project Gutenberg Online\n",
      "Distributed Pro\n",
      "tensor([[    1,   920, 28723,  ..., 28725,  3364, 28705]])\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "jupyter": {
     "is_executing": true
    },
    "ExecuteTime": {
     "start_time": "2024-12-29T11:07:06.823044Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 提取注意力分数\n",
    "# 模型前向\n",
    "with torch.no_grad():\n",
    "    output = model(input_ids.to(device), output_attentions=True)"
   ],
   "id": "241cf83bdbfc0246",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model: mistralai/Mistral-7B-Instruct-v0.1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\fzkuj\\anaconda3\\envs\\nano\\Lib\\site-packages\\accelerate\\utils\\modeling.py:1390: UserWarning: Current model requires 4096 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.\n",
      "  warnings.warn(\n",
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:24<00:00, 12.39s/it]\n",
      "Some parameters are on the meta device because they were offloaded to the cpu and disk.\n",
      "Using the latest cached version of the dataset since emozilla/pg19 couldn't be found on the Hugging Face Hub\n",
      "Found the latest cached dataset configuration 'default' at C:\\Users\\fzkuj\\.cache\\huggingface\\datasets\\emozilla___pg19\\default\\0.0.0\\c021754c8e01c5b1cc83a1f549c1f97fbbb756b8 (last modified on Sat Dec 28 15:02:36 2024).\n",
      "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation=\"eager\"` when loading the model.\n"
     ]
    }
   ],
   "execution_count": null
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-29T11:05:11.455421Z",
     "start_time": "2024-12-29T11:05:11.450641Z"
    }
   },
   "cell_type": "code",
   "source": "print(output.keys())",
   "id": "a4176d25b06cb8a",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33\n"
     ]
    }
   ],
   "execution_count": 4
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "output[\"attentions\"]",
   "id": "75b41a4cd84f22df"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "\n",
    "\n",
    "# 可视化某一层的注意力分数\n",
    "def visualize_attention_scores(attention_scores, layer_idx, head_idx, seq_len, output_file=None):\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    attention_matrix = attention_scores[layer_idx][0, head_idx, :seq_len, :seq_len]\n",
    "    plt.imshow(attention_matrix, cmap=\"viridis\", aspect=\"auto\")\n",
    "    plt.colorbar(label=\"Attention Score\")\n",
    "    plt.title(f\"Attention Layer {layer_idx + 1}, Head {head_idx + 1}\")\n",
    "    plt.xlabel(\"Key Position\")\n",
    "    plt.ylabel(\"Query Position\")\n",
    "    if output_file:\n",
    "        plt.savefig(output_file)\n",
    "    plt.show()\n"
   ],
   "id": "4ad947e8d99977da"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "\n",
    "attention_scores = extract_attention_scores(model, input_ids, device)\n",
    "\n",
    "# 保存或可视化\n",
    "layer_idx = 0  # 可视化第几层\n",
    "head_idx = 0   # 可视化第几个注意力头\n",
    "seq_len = 1024  # 只显示前 1024 长度的分数（避免图太大）\n",
    "visualize_attention_scores(attention_scores, layer_idx, head_idx, seq_len, output_file=\"attention_layer_1_head_1.png\")\n",
    "\n",
    "# 保存完整注意力分数到文件\n",
    "torch.save(attention_scores, \"attention_scores.pt\")\n",
    "print(\"Attention scores saved to 'attention_scores.pt'\")"
   ],
   "id": "add03ca42b8c221"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-15T07:36:53.393392Z",
     "start_time": "2025-01-15T07:36:52.602134Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "def get_alibi_slope(num_heads):\n",
    "    x = (2 ** 8) ** (1 / num_heads)\n",
    "    return (\n",
    "        torch.tensor([1 / x ** (i + 1) for i in range(int(num_heads))])\n",
    "    )\n",
    "\n",
    "a = get_alibi_slope(36)\n",
    "print(a)"
   ],
   "id": "a36895786a9802b7",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.8572, 0.7349, 0.6300, 0.5400, 0.4629, 0.3969, 0.3402, 0.2916, 0.2500,\n",
      "        0.2143, 0.1837, 0.1575, 0.1350, 0.1157, 0.0992, 0.0850, 0.0729, 0.0625,\n",
      "        0.0536, 0.0459, 0.0394, 0.0338, 0.0289, 0.0248, 0.0213, 0.0182, 0.0156,\n",
      "        0.0134, 0.0115, 0.0098, 0.0084, 0.0072, 0.0062, 0.0053, 0.0046, 0.0039])\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-15T15:24:17.999284Z",
     "start_time": "2025-01-15T15:24:17.264025Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch\n",
    "\n",
    "def get_alibi_slope(num_heads):\n",
    "    x = (2 ** 16) ** (1 / num_heads)\n",
    "    # 生成正值的一半 slopes\n",
    "    half_slopes = torch.tensor([1 / x ** (i + 1) for i in range(int(num_heads / 2))])\n",
    "    # 拼接正值和负值，形成对称 Tensor\n",
    "    full_slopes = torch.cat([half_slopes, -half_slopes.flip(0)])\n",
    "    return full_slopes\n",
    "\n",
    "# 测试\n",
    "a = get_alibi_slope(12)\n",
    "print(a)"
   ],
   "id": "847f7bab9b990751",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.3969,  0.1575,  0.0625,  0.0248,  0.0098,  0.0039, -0.0039, -0.0098,\n",
      "        -0.0248, -0.0625, -0.1575, -0.3969])\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-01-15T16:18:41.263534Z",
     "start_time": "2025-01-15T16:18:41.259202Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def get_alibi_slope(num_heads):\n",
    "    head = int(num_heads * 0.5)\n",
    "\n",
    "    x = (2 ** 8) ** (1 / num_heads)\n",
    "    # 生成正值的一半 slopes\n",
    "    pos_slopes = torch.tensor([1 / x ** (i + 1) for i in range(head)])\n",
    "\n",
    "    x = (2 ** 8) ** (1 / head)\n",
    "    # 生成负值的一半 slopes\n",
    "    neg_slopes = torch.tensor([1 / x ** (i + 2) for i in range(head)])\n",
    "    # 拼接正值和负值，形成对称 Tensor\n",
    "    full_slopes = torch.cat([pos_slopes, -neg_slopes.flip(0)])\n",
    "    return full_slopes\n",
    "\n",
    "# 测试\n",
    "a = get_alibi_slope(12)\n",
    "print(a)"
   ],
   "id": "c7a89ab646f2b1f6",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 0.6300,  0.3969,  0.2500,  0.1575,  0.0992,  0.0625, -0.0016, -0.0039,\n",
      "        -0.0098, -0.0248, -0.0625, -0.1575])\n"
     ]
    }
   ],
   "execution_count": 17
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "3da45cf610132ce1"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
